import streamlit as st
import pandas as pd
import numpy as np
import requests
import re
import json
import csv
import joblib
import torch
import pgeocode
import folium
from datetime import datetime
from pathlib import Path
from branca.element import IFrame
from folium.plugins import MarkerCluster
from streamlit_folium import folium_static
from gensim.models import KeyedVectors
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from reasoner.reasoner import Reasoner

from openai import OpenAI




import os

# Model identifier sent with chat-completion requests.
VALID_MODEL_NAME = "deepseek-r1"

# OpenAI-compatible client pointed at the api.lkeap.cloud.tencent.com endpoint.
# The key is taken from the LKEAP_API_KEY environment variable so that no
# credential has to be written into (or committed with) this file. When the
# variable is unset the key stays an empty string, exactly as before.
client = OpenAI(
    api_key=os.environ.get("LKEAP_API_KEY", ""),
    base_url="https://api.lkeap.cloud.tencent.com/v1",
)

# Load knowledge base
def load_knowledge_base(csv_path):
    """Read the symptom/ICD-9 CSV and flatten it into prompt-ready text.

    Every row contributes one line of the form ``<symptom> -> <code>``. Both
    fields are whitespace-stripped and dots are removed from the code.

    Args:
        csv_path: Path of a CSV file that has ``symptom`` and ``code`` columns.

    Returns:
        The generated lines joined with newlines. If anything fails while
        opening or parsing the file, the error is printed and ``""`` is
        returned instead.
    """
    try:
        with open(csv_path, 'r', encoding='utf-8-sig') as handle:
            mapping_lines = []
            for record in csv.DictReader(handle):
                symptom = record['symptom'].strip()
                code = record['code'].strip().replace('.', '')
                mapping_lines.append(f"{symptom} -> {code}")
            return "\n".join(mapping_lines)
    except Exception as err:
        # Best-effort: callers get an empty knowledge base rather than a crash.
        print(f"Error loading KB: {err}")
        return ""

# Relative path of the symptom -> ICD-9 mapping CSV (resolved against the
# current working directory when opened).
knowledge_base_path = "ICD9_symptom_mapping.csv"
# Text form of the mapping, built when this module is executed and interpolated
# into the prompt templates (get_promote, get_promote2, get_prompt). It is an
# empty string if load_knowledge_base could not read or parse the file.
knowledge_base = load_knowledge_base(knowledge_base_path)

class BioBertEntityExtractor:
    """Named-entity extractor backed by the token-classification model in ``biobert_model``.

    Only entities whose group is ``Disease`` or ``Chemical`` are reported;
    ``Chemical`` is relabelled as ``Drug``.
    """

    # Entity groups that are kept, mapped to the label reported to callers.
    # Any group not listed here is dropped.
    _GROUP_LABELS = {"Disease": "Disease", "Chemical": "Drug"}

    def __init__(self):
        """Load model and tokenizer from ``biobert_model`` and build the NER pipeline.

        The pipeline runs on CUDA device 0 when ``torch.cuda.is_available()``
        is true and on the CPU (``device=-1``) otherwise.
        """
        self.model = AutoModelForTokenClassification.from_pretrained("biobert_model")
        self.tokenizer = AutoTokenizer.from_pretrained("biobert_model")
        self.pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            aggregation_strategy="simple",
        )

    def extract_entities(self, text):
        """Run NER over *text* and return the disease / drug mentions.

        Args:
            text: The text to analyse; passed unchanged to the pipeline.

        Returns:
            A list of ``{"entity_group": "Disease" | "Drug", "word": str}``
            dicts in pipeline order. ``word`` has ``##`` markers removed and
            is lower-cased and stripped.
        """
        filtered = []
        for ent in self.pipeline(text):
            # Drop any B-/I- tagging prefix so both forms map to one group name.
            group = ent["entity_group"].replace("B-", "").replace("I-", "")
            label = self._GROUP_LABELS.get(group)
            if label is None:
                continue
            word = ent["word"].replace("##", "").lower().strip()
            filtered.append({"entity_group": label, "word": word})
        return filtered

def get_promote(text, entities):
    """Build the long-form ICD-9 coding prompt.

    Interpolates the module-level ``knowledge_base`` whitelist, the user's
    text and the extracted entities into a prompt that shows the model a
    JSON-like output format containing a ``diagnoses`` list, plus one worked
    example.

    Args:
        text: The user's free-text input, inserted under "Input Format".
        entities: The extracted entities, inserted (via normal f-string
            formatting) under "Medical Knowledge Context".

    Returns:
        The assembled prompt string.

    NOTE(review): no caller of this function is visible in this file;
    ``extract_icd9_deepseek`` uses ``get_prompt`` instead. Confirm whether
    this variant is still needed.
    """
    return f"""
# Role Setting
You are a medical coding assistant that strictly follows a whitelist and can only use the following preset codes:

{knowledge_base}

# Input Format
Input as a series of users' natural language:

{text}

# Medical Knowledge Context
You have identified the following clinically relevant entities:

{entities}

# Task Requirements
1. Identify the severity of these entities(generally in English) from the input. Ignore the parts of the patient's language that are unrelated to the symptoms. If the user's input is unrelated to the medical field, it is possible to have a conversation with the user without the set personality.
2. Every entity must match only the standard terms in this 【strict list】. This entity can match the similar meaning of an element in a strict list, or it can match the symptoms generated by an element in a strict list.
3. Sort the matches by severity(descending order).


# Output Format
{{
    "diagnoses": [
        {{
            "standard_term": "Standardized medical term 1",
            "match_status": "Matched/Unknown",  # Matching status
            "severity": 10,            # Severity, which is 0 if Matching status is Unknown
            "icd9_code": ""                    # Leave empty if unmatched
        }},
        {{
            "standard_term": "Standardized medical term 2",
            "match_status": "Matched/Unknown",
            "severity": 4,
            "icd9_code": ""
        }}
    ]
}}

# Example
Input: My hands and feet feel constantly numb, like I’m wearing invisible gloves and socks all the time.
Output:
{{
    "diagnoses": [
        {{"standard_term": "Shigella boydii", "match_status": "Matched", "severity": 2, "icd9_code": "42"}},
        {{"standard_term": "Poisoning by other diuretics", "match_status": "Matched", "severity": 1, "icd9_code": "9744"}}
    ]
}}
"""

def get_promote2(pre_res, intermediate):
    """Build the patient-facing explanation prompt.

    The prompt embeds the disease-level results, the gene/protein-level
    intermediate results and the module-level ``knowledge_base`` text, and
    asks the model to explain the likely diseases in plain language with
    suggestions and a disclaimer.

    Args:
        pre_res: Disease results; the prompt describes them as a list of
            ``(ICD-9 code, truth value)`` tuples. Inserted via normal
            f-string formatting.
        intermediate: Gene/protein results; the prompt describes them as a
            list of ``(gene id, truth value)`` tuples. Inserted the same way.

    Returns:
        The assembled prompt string.
    """
    return f"""
# Role Setting
You are a senior medical translator who needs to translate professional biomarker information into natural language explanations that are easy for patients to understand.
Previously, through analysis, you obtained some possible symptoms from the patient's description:
{pre_res}
The format of this data is:
[('4254', <TruthValue: %1.00;0.99% (k=1)>),
 ('42789', <TruthValue: %1.00;0.99% (k=1)>),
 ('3320', <TruthValue: %1.00;0.98% (k=1)>),
 ('3591', <TruthValue: %1.00;0.96% (k=1)>),
 ('4271', <TruthValue: %1.00;0.96% (k=1)>)]
In this list, the first element of each tuple is the ICD9 number of the disease, and the mapping rule is based on {knowledge_base}. The second one is the true value. The format of the truth value is% f; c%， The higher the value of c, the greater the degree of certainty.

Now, you have obtained an intermediate result from the backend, which is a list of genes and proteins related to these diseases:
{intermediate}
The format of this data is:
[('GENE:32600', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:29744', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:32213', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:33672', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:29386', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:29404', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:30994', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:32523', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:32532', <TruthValue: %1.00;0.90% (k=1)>),
 ('GENE:24563', <TruthValue: %1.00;0.90% (k=1)>)]
In this list, the first element of each tuple is the gene number, and the second element is also the truth value, in the format of% f; c%， The higher the value of c, the greater the degree of certainty.

With these two pieces of information, you need to:
1. Explain to the patient what diseases they may have. 
2. If the user further inquires about the details of the illness or the cause of the illness, then explain to the patient why these symptoms occur, which genes are related to them in the body, and name up to five related genes. If there is no further inquiry, the user may not be presented with genetic related reasoning.
3. Replace professional medical terminology with more common vocabulary.
4. Finally, provide 1-3 specific suggestions and a disclaimer. It is also necessary to consult a professional doctor at the hospital.
"""

def get_explaination(text, entities):
    """Ask the model for a plain-language explanation and return it as one string.

    Despite their names, the two arguments are forwarded positionally to
    ``get_promote2`` as ``pre_res`` and ``intermediate``.

    The request is made without streaming. With ``stream=True`` the client
    returns an iterator of chunks that has no ``.choices`` attribute, so
    reading ``response.choices[0].message`` always failed and this function
    only ever returned an error string.

    Args:
        text: Value passed to ``get_promote2`` as ``pre_res``.
        entities: Value passed to ``get_promote2`` as ``intermediate``.

    Returns:
        The stripped reply text (``""`` if the reply carried no content), or
        a string starting with ``"Error："`` if the request raised.
    """
    full_prompt = get_promote2(text, entities)

    try:
        response = client.chat.completions.create(
            model=VALID_MODEL_NAME,
            messages=[{"role": "user", "content": full_prompt}],
            temperature=0.5,
            max_tokens=1000,
            stop=["---"],
        )
        # content may be None (e.g. an empty completion); treat that as "".
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        return f"Error：{str(e)}"
    
def process_protein_list(protein_list):
    """Normalise protein identifiers to the ``PROTEIN:<id>`` form.

    Args:
        protein_list: Iterable of identifiers (strings, numbers, or anything
            else that can be formatted into a string).

    Returns:
        A new list, in input order, where:

        * strings already starting with ``PROTEIN:`` are kept unchanged;
        * values accepted by ``int()`` become ``PROTEIN:<int value>``;
        * everything else becomes ``PROTEIN:<value>`` using its normal
          string formatting.
    """
    processed = []
    for item in protein_list:
        if isinstance(item, str) and item.startswith('PROTEIN:'):
            processed.append(item)
            continue
        try:
            processed.append(f'PROTEIN:{int(item)}')
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures are handled here; a bare ``except``
            # would also swallow KeyboardInterrupt / SystemExit.
            processed.append(f'PROTEIN:{item}')
    return processed
    
def procedure2gene(procedures):
    """Predict linked nodes for each code, reason over them, and rank the output.

    Steps:

    1. Load the embeddings (``node2vec_embeddings.txt``) and the classifier
       (``rf_model.pkl``).
    2. For every entry of *procedures* that exists in the embedding
       vocabulary, classify the pair (entry, candidate) for every vocabulary
       node that is not itself in *procedures*. Candidates predicted as ``1``
       are kept.
    3. Normalise the kept candidates with ``process_protein_list`` and pass
       each list to ``Reasoner(5).reason``.
    4. Concatenate the reasoner outputs and sort them by confidence,
       highest first.

    Args:
        procedures: Iterable of node keys (the caller passes ICD-9 code
            strings). Entries missing from the vocabulary produce an empty
            candidate list, which is still passed to the reasoner.

    Returns:
        ``(sorted_results, sorted_intermediates)``: two lists of
        ``(identifier, truth_value)`` pairs. Both are empty when
        *procedures* is empty.

    NOTE(review): the meaning of the ``5`` passed to ``Reasoner`` is not
    visible from this file.
    """
    procedures = list(procedures)
    if not procedures:
        # Previously this path built an empty DataFrame and then failed with
        # KeyError on its missing 'protein' column.
        return [], []

    # === Load Pretrained Models ===
    model = KeyedVectors.load_word2vec_format("node2vec_embeddings.txt")
    clf = joblib.load("rf_model.pkl")

    # === Collect All Candidate Nodes from Vocabulary (excluding procedures) ===
    excluded = set(procedures)  # O(1) membership instead of scanning the list
    candidate_proteins = [node for node in model.key_to_index if node not in excluded]
    # Candidate vectors are the same for every procedure, so stack them once.
    candidate_vectors = (
        np.vstack([model[prot] for prot in candidate_proteins])
        if candidate_proteins
        else None
    )

    # === Predict Links for Each Procedure ===
    predicted_links = []
    for proc in procedures:
        linked = []
        if candidate_vectors is not None and proc in model:
            # One feature row per candidate: [procedure vector | candidate vector],
            # the same layout as concatenating the two vectors pair by pair.
            proc_block = np.tile(model[proc], (len(candidate_proteins), 1))
            features = np.hstack((proc_block, candidate_vectors))
            predictions = clf.predict(features)
            linked = [
                prot
                for prot, pred in zip(candidate_proteins, predictions)
                if pred == 1
            ]
        predicted_links.append(process_protein_list(linked))

    # === Reason over each predicted list ===
    reasoner = Reasoner(5)
    all_results = []
    all_intermediates = []
    for protein_list in predicted_links:
        result, intermediate_results = reasoner.reason(protein_list)
        all_results.extend(result)
        all_intermediates.extend(intermediate_results)

    def get_c_value(truth_value):
        # The text form of a truth value looks like "%1.00;0.99% (k=1)";
        # the number between ';' and the following '%' is the confidence.
        parts = str(truth_value).split(";")
        return float(parts[1].split("%")[0])

    def by_confidence(pair):
        return get_c_value(pair[1])

    sorted_results = sorted(all_results, key=by_confidence, reverse=True)
    sorted_intermediates = sorted(all_intermediates, key=by_confidence, reverse=True)

    return sorted_results, sorted_intermediates


def get_prompt(text, entities):
    """Build the compact ICD-9 coding prompt used by ``extract_icd9_deepseek``.

    Interpolates the module-level ``knowledge_base`` whitelist, the user's
    text and the extracted entities, and shows the model an output format
    containing a ``diagnoses`` list.

    Args:
        text: The user's free-text input, inserted under "User Input".
        entities: The extracted entities, inserted (via normal f-string
            formatting) under "Extracted Entities".

    Returns:
        The assembled prompt string.
    """
    return f"""
# Role Setting
You are a medical coding assistant. Use only this whitelist:

{knowledge_base}

# User Input
{text}

# Extracted Entities:
{entities}

# Task:
1. Identify severity.
2. Match to whitelist terms.
3. Sort by severity.

# Format:
{{
  "diagnoses": [
    {{"standard_term": "...", "match_status": "Matched/Unknown", "severity": X, "icd9_code": ""}},
    ...
  ]
}}
"""

def _parse_json_object(raw):
    """Decode JSON from a model reply, tolerating text around the object.

    Tries the whole reply first. If that fails, retries on the span from the
    first ``{`` to the last ``}``, which covers replies wrapped in Markdown
    code fences or preceded/followed by commentary.

    Raises:
        ValueError: If *raw* is not a string or no decodable JSON is found.
    """
    if not isinstance(raw, str):
        raise ValueError("model returned no text content")
    try:
        return json.loads(raw)
    except ValueError:
        start, end = raw.find("{"), raw.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(raw[start:end + 1])


def extract_icd9_deepseek(text, entities):
    """Ask the model to map the user's text and entities to ICD-9 codes.

    Args:
        text: The user's free-text input.
        entities: Entities extracted from *text*; forwarded to ``get_prompt``.

    Returns:
        A dict that always has a ``"diagnoses"`` list. On success it is the
        decoded model reply. If the request fails, the reply cannot be
        decoded, or the reply has no ``diagnoses`` list, the result is
        ``{"diagnoses": [], "error": <message>}``.
    """
    prompt = get_prompt(text, entities)
    try:
        response = client.chat.completions.create(
            model=VALID_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            # The API expects an object here, not the bare string "json_object".
            response_format={"type": "json_object"},
        )
        parsed = _parse_json_object(response.choices[0].message.content)
    except Exception as e:
        return {"diagnoses": [], "error": str(e)}

    # Callers index result["diagnoses"] directly, so guarantee its presence.
    if not isinstance(parsed, dict) or not isinstance(parsed.get("diagnoses"), list):
        return {
            "diagnoses": [],
            "error": "model response did not contain a 'diagnoses' list",
        }
    return parsed

# Streamlit UI
# Page header and the free-text symptom input.
st.markdown("# Hi, welcome, I am ReC! Please let me know what can I do for you?")
st.markdown("### You can describe your symptoms in the following chatbox:")

placeholder_text = (
    "I’ve been completely knocked out—so drowsy I can’t keep my eyes open, and my mind feels scrambled..."
)

# Rough height estimate: ~100 characters per line, ~25 px per line.
chars_per_line = 100
line_count = len(placeholder_text) // chars_per_line + 1
# st.text_area rejects heights below 68 px in current Streamlit releases, and
# the estimate for this placeholder is only 50 px, so clamp to that minimum.
estimated_height = max(68, line_count * 25)

user_query = st.text_area(
    "Enter your clinical question:",
    placeholder=placeholder_text,
    height=estimated_height,
    key="user_query"
)

submit_button = st.button("Submit")


# Runs the full pipeline once the user submits a non-empty query:
# entity extraction -> ICD-9 matching -> link prediction / reasoning ->
# streamed plain-language explanation.
if submit_button and user_query:
    st.info("Extracting medical entities and reasoning with ReC...")

    # Step 1: Entity extraction
    # NOTE(review): the extractor (and its model) is rebuilt on every submit.
    extractor = BioBertEntityExtractor()
    entities = extractor.extract_entities(user_query)

    # Step 2: ICD-9 matching
    result = extract_icd9_deepseek(user_query, entities)
    output = result
    if isinstance(result, dict) and result.get("error"):
        # Surface the failure instead of silently continuing with no codes.
        st.warning(f"ICD-9 matching failed: {result['error']}")

    diagnoses = result.get("diagnoses") if isinstance(result, dict) else None
    icd9_codes = []
    for entry in diagnoses or []:
        code = entry.get("icd9_code") if isinstance(entry, dict) else None
        # Unmatched diagnoses carry an empty (or missing) code; skip those so
        # only real codes reach the link-prediction step.
        code = "" if code is None else str(code).strip()
        if code:
            icd9_codes.append(code)

    # Step 3: Procedure-to-gene reasoning
    if icd9_codes:
        sorted_results, sorted_intermediates = procedure2gene(icd9_codes)
    else:
        sorted_results, sorted_intermediates = [], []

    # Step 4: Prompt construction
    full_prompt = get_promote2(sorted_results, sorted_intermediates)

    # Step 5: Stream output in real time
    st.markdown("**AI-generated explanation:**")
    explanation_box = st.empty()
    generated_text = ""

    try:
        response = client.chat.completions.create(
            model=VALID_MODEL_NAME,
            messages=[{"role": "user", "content": full_prompt}],
            temperature=0.5,
            max_tokens=1000,
            stop=["---"],
            stream=True
        )

        for chunk in response:
            # Some chunks carry no choices at all, and a delta's content can
            # be None; appending None to a str would abort the whole stream
            # with a TypeError.
            if not chunk.choices:
                continue
            piece = getattr(chunk.choices[0].delta, "content", None)
            if piece:
                generated_text += piece
                explanation_box.markdown(generated_text)

    except Exception as e:
        st.error(f"Error：{str(e)}")
  